from typing import List, Union, Tuple, Optional
from tools.codegen.model import (Type, BaseTy, BaseType, OptionalType,
                                 ListType, OperatorName, FunctionSchema,
                                 Return, TensorOptionsArguments, Argument)
from tools.codegen.api.types import (CType, BaseCppType, BaseCType, OptionalCType,
                                     NamedCType, deviceT, layoutT,
                                     VectorCType, boolT, longT, doubleT, ListCType, stringT,
                                     scalarT, scalarTypeT, memoryFormatT)

# C++ type that lazy IR uses for tensor-like arguments (at::Tensor, and at::Scalar): torch::lazy::Value.
valueT = BaseCppType('torch::lazy', 'Value')
# HACK: a distinct name for the type used for Tensor[] (TensorList) arguments, which
# process_ir_type maps to a single Value. The proper fix is to refactor the data model so
# each schema arg is an object, making special properties of an arg easier to represent.
# NOTE(review): this is constructed from the same namespace/name as valueT, so if
# BaseCppType compares by field values (presumably a dataclass) the two compare equal
# and cannot be told apart by equality — confirm before relying on the distinction.
tensorListValueT = BaseCppType('torch::lazy', 'Value')

def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    Translate a NativeFunctions schema type into the C++ type used by lazy tensor codegen.

    The translation:
     (1) replaces at::Tensor with lazy::Value (a wrapper around the lazy::Node that
         Lazy IR uses to represent a tensor); at::Scalar is wrapped in a lazy::Value too;
     (2) wraps each leaf C++ type in a BaseCType;
     (3) uses owning value types in place of C++ reference types
         (e.g. a vector instead of IntArrayRef).

    Optional and List types are translated recursively. Tensor[] and Tensor?[]
    get dedicated handling instead ('tensor-like' lists).

    Coverage is deliberately partial: unsupported types raise AssertionError, and
    more cases are expected as the codegen is applied to more operators.

    Raises:
        AssertionError: if `typ` (or a type nested inside it) is not supported.
    """
    if isinstance(typ, BaseType):
        # Leaf schema type -> C++ type. at::Scalar has special handling and is
        # wrapped in a lazy::Value just like at::Tensor.
        leaf_cpp_types = {
            BaseTy.Tensor: valueT,
            BaseTy.Scalar: valueT,
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        if typ.name not in leaf_cpp_types:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
        return BaseCType(leaf_cpp_types[typ.name])

    if isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))

    if isinstance(typ, ListType):
        elem_name = str(typ.elem)
        if elem_name == 'Tensor?':
            # TODO(whc) is this actually correct? or should it use a VectorCType like the generic list case
            return ListCType(OptionalCType(BaseCType(valueT)))
        if elem_name == 'Tensor':
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        return VectorCType(process_ir_type(typ.elem))

    raise AssertionError(f"unrecognized type {repr(typ)}")


def isValueType(typ: CType) -> bool:
    """
    Report whether an already-converted C++ type is Value-like.

    This is the post-conversion counterpart of "tensor-like". Optional, List and
    Vector wrappers are peeled off. What remains is Value-like only if it is a
    BaseCType holding valueT or scalarT. at::Scalar is wrapped in a lazy Value,
    while other 'scalar' types stay plain scalars in the IR.
    """
    inner = typ
    while isinstance(inner, (OptionalCType, ListCType, VectorCType)):
        inner = inner.elem
    return isinstance(inner, BaseCType) and inner.type in (valueT, scalarT)

def isWrappedScalarType(typ: Type) -> bool:
    """
    Report whether a schema type is a c10::Scalar that lazy codegen wraps in a Value.

    Conversion rewrites scalarT into valueT, which loses the fact that the value
    was originally a Scalar. This predicate, evaluated on the original schema
    type, lets callers record that information. Optional and List wrappers are
    peeled off before the check.
    """
    inner = typ
    while isinstance(inner, (OptionalType, ListType)):
        inner = inner.elem
    return isinstance(inner, BaseType) and inner.name == BaseTy.Scalar

def isGeneratorType(typ: Type) -> bool:
    """
    Report whether a schema type is a Generator, possibly wrapped in Optional.

    Only Optional wrappers are peeled off; a list of generators is not treated
    as a generator type.
    """
    inner = typ
    while isinstance(inner, OptionalType):
        inner = inner.elem
    return isinstance(inner, BaseType) and inner.name == BaseTy.Generator

class LazyArgument:
    """
    One argument of a native function schema, annotated for lazy IR codegen.

    Keeps the original schema type alongside the converted lazy IR type and a few
    derived flags. Generator arguments get no lazy type, because lazy IR cannot
    express them; reading `lazy_type` on one raises AssertionError.
    """

    # argument name, as spelled in the native schema
    name: str
    # the argument's type as written in the native schema
    orig_type: Type
    # process_ir_type(orig_type); None for generator arguments
    lazy_type_: Optional[CType]
    # true if orig_type is (or wraps) an at::Scalar that will be wrapped in a lazy Value
    is_wrapped_scalar: bool
    # true if orig_type is an (optional) Generator
    is_generator: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument):
        arg_type = arg.type
        self.name = arg.name
        self.orig_type = arg_type
        self.is_generator = isGeneratorType(arg_type)

        if not self.is_generator:
            self.lazy_type_ = process_ir_type(arg_type)
        else:
            assert isinstance(arg_type, OptionalType), "We expect all generators are optional since currently they are"
            # TorchScript IR (and XLA) have no handling for generators: the op falls back
            # to eager when the optional generator has a value, and otherwise it is null
            # and safe to leave out of lazy IR.
            self.lazy_type_ = None

        self.is_wrapped_scalar = isWrappedScalarType(arg_type)
        # Generators have no lazy_type, so they are never lazy values.
        self.is_lazy_value = False if self.is_generator else isValueType(self.lazy_type)

    @property
    def lazy_type(self) -> CType:
        """The converted lazy IR type; raises AssertionError for generator arguments."""
        converted = self.lazy_type_
        assert converted is not None, f"Attempted to access lazy_type for invalid argument {self.name}"
        return converted

# Inspired by FunctionSchema, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (as used in the YAML).
# Instead it carries type information taken from a native FunctionSchema, modified for use
# with IR nodes, and preserves the original argument names.
class LazyIrSchema:
    """
    The schema of a lazy IR node, derived from a native FunctionSchema.

    Arguments are split into positional and keyword groups, each kept in schema
    order and wrapped as LazyArgument. At most one Generator argument is
    supported. It is remembered in `generator_arg` whether it is positional or
    keyword, because lazy IR itself cannot represent it.
    """

    # The name of the operator this function schema describes.
    name: 'OperatorName'

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple['Return', ...]

    # If this schema has a Generator arg, record its name and original (native) type here.
    # A LazyArgument (with is_generator set) is still built for it; filtered_args leaves
    # it out unless generator=True, since lazy IR doesn't support it.
    generator_arg: Optional[NamedCType] = None

    def __init__(self, func: FunctionSchema):
        self.generator_arg = None

        positional_args: List[LazyArgument] = []
        for arg_field in ["pre_self_positional",
                          "self_arg",
                          "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                # self_arg is a wrapper around a single Argument
                positional_args.append(self._make_lazy_arg(func.arguments.self_arg.argument))
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    self._make_lazy_arg(arg) for arg in getattr(func.arguments, arg_field))
        self.positional_args = tuple(positional_args)

        keyword_args: List[LazyArgument] = []
        for arg_field in ["pre_tensor_options_kwarg_only",
                          "tensor_options",
                          "post_tensor_options_kwarg_only",
                          "out"]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is None:
                continue
            if isinstance(curr_args, TensorOptionsArguments):
                # tensor_options bundles several arguments; expand them in order
                curr_args = curr_args.all()
            keyword_args.extend(self._make_lazy_arg(arg) for arg in curr_args)
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns

    def _make_lazy_arg(self, arg: Argument) -> LazyArgument:
        """Wrap `arg` as a LazyArgument, recording it as `generator_arg` if it is a Generator."""
        # Checked for positional args too: some schemas take `Generator?` positionally,
        # and it must still be recorded because lazy IR drops the argument itself.
        if isGeneratorType(arg.type):
            assert self.generator_arg is None, "We expect there is only one generator arg"
            self.generator_arg = NamedCType(arg.name, arg.type)
        return LazyArgument(arg)

    @property
    def node_name(self) -> str:
        """
        Return camel-case version of op in node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(word.capitalize() for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        """String form of the operator's base name (without the overload name)."""
        return f"{self.name.name}"

    @property
    def base_name(self) -> str:
        """String form of the `base` attribute of the operator's base name."""
        return f"{self.name.name.base}"

    def filtered_args(self, positional: bool = True, keyword: bool = True,
                      values: bool = True, scalars: bool = True, generator: bool = False) -> List[LazyArgument]:
        """
        Return a filtered view of the arguments, preserving schema order.

        Some parts of the code care about kwargs vs args (TS lowerings), others about
        whether an arg must be wrapped in a lazy value or left alone. Generators are
        special cased: they are needed for fallback/shape-inference but not supported
        in TS lowerings, and are therefore omitted from lazy IR by default.

        Args:
            positional: include positional arguments.
            keyword: include keyword arguments.
            values: include arguments that are (or contain) lazy IR values.
            scalars: include arguments that are not lazy IR values.
            generator: include Generator arguments. This only has an effect when
                `scalars` is True, since generators are never lazy values.
        """
        args: List[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [a for a in args if not a.is_lazy_value and (generator or not a.is_generator)]

        return []

    @property
    def positional_values(self) -> List[LazyArgument]:
        return self.filtered_args(positional=True, keyword=False, values=True, scalars=False)

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(positional=True, keyword=False, values=False, scalars=True)

    @property
    def keyword_values(self) -> List[LazyArgument]:
        return self.filtered_args(positional=False, keyword=True, values=True, scalars=False)

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(positional=False, keyword=True, values=False, scalars=True)
